"""
This module is used to generate Onnx models for testing operator support, and to validate the generated models to ensure they produce spec compliant results.
It's expected that the generated model's nodes will be chained Ops of the same type or constants.
Inputs are stored to validate the model results, to debug issues with generating test models, and to potentially generate deserializable test data for the model
"""
from pathlib import Path
from typing import List, Optional, Tuple, NewType, TypeAlias
from numpy.typing import ArrayLike, NDArray
import onnx
import onnxruntime
import numpy as np
from dataclasses import dataclass, field, InitVar

# TODO: need to come up with some examples of valid inputs
TensorMap: TypeAlias = (
    dict[str, ArrayLike] | List[ArrayLike] | str | Tuple[str, ArrayLike]
)
"""TypeAlias for the inputs to OnnxGraphBuilder. if it's a dictionary, the keys are the names of the inputs and the values are the input data. 
If it's a list, for non tuple elements, names are autogenerated. For tuple elements, the first element is the name of the input and the second element is the value of the input.

Note:

The ">" char can be used to indicate the output of the previous node if the previous node has a single output. Example:

[">", np.array([1,2,3])]

"""

NodeOutput = NewType("NodeOutput", str)


def validate_sequence(arr: List | Tuple) -> NDArray:
    """Function to validate that all elements in a sequence are of the same type.


    Args:
        arr (List | Tuple): Sequence to validate

    Raises:
        ValueError: raised if the elements in the sequence are not of the same type.
    """
    el_type = type(arr[0])
    for el in arr:
        if type(el) != el_type:
            raise ValueError(
                f"Expected all elements to be of type {el_type} but got {type(el)}"
            )
    return np.array(arr)


# Module-level counters used to autogenerate unique names:
# var_counter -> "var_N" input names, op_counter -> "<op><N>" node names.
var_counter = 0
op_counter = 0


def validate_input(
    input_data: TensorMap,
    value_names: set[NodeOutput],
    out_names: set[str],
) -> dict[str, ArrayLike | NodeOutput]:
    """helper function to validate inputs to a node

    Args:
        input_data (dict[str, ArrayLike  |  str]): The inputs to the node that need to be validated
        value_names (set[NodeOutput]): the list of names of outputs from nodes already in the graph
        out_names (set[str]): set of names of onnx values produced by the Node, used to make sure that the input names don't conflict with the output names.

    Raises:
        ValueError: _description_
        TypeError: _description_

    Returns:
        dict[str, ArrayLike | NodeOutput]: _description_
    """

    def kv_check(value_names, out_names, res, var_name, var_value):
        if var_name in out_names:
            raise ValueError(
                f"Tensor Input {var_name} cannot have the same name as a tensors output"
            )
        if isinstance(var_value, str):
            if var_value not in value_names:
                raise TypeError(
                    f"NodeOutput {var_value} not found in outputs. Please provide the output of a previous node as an input."
                )
            res[var_name] = NodeOutput(var_value)
        else:
            res[var_name] = var_value

    def next_var():
        global var_counter
        var_counter += 1
        return f"var_{var_counter}"

    res: dict[str, ArrayLike | NodeOutput] = {}
    match input_data:
        case dict():
            for k, v in input_data.items():
                kv_check(value_names, out_names, res, k, v)
        case list():
            for item in input_data:
                if isinstance(item, str):
                    if item == ">":
                        res[item] = NodeOutput(item)
                    if item not in value_names:
                        raise TypeError(
                            f"NodeOutput {item} not found in outputs. Please provide the output of a previous node as an input."
                        )
                elif isinstance(item, tuple):
                    # until I come up with something better
                    if isinstance(item[0], str):
                        (
                            kv_check(
                                value_names,
                                out_names,
                                res,
                                next_var(),
                                item[1],
                            ),
                        )

                    else:  # it's an arraylike which will be validated later
                        res[next_var()] = item
                elif isinstance(item, np.ndarray | list) or np.isscalar(item):
                    res[f"var_{var_counter}"] = item

    return res


def _get_tensor(name: str, arr: np.ndarray):
    """Build a tensor ValueInfoProto for `arr` (dtype + shape) under `name`."""
    dtype_code = onnx.helper.np_dtype_to_tensor_dtype(arr.dtype)
    return onnx.helper.make_tensor_value_info(name, dtype_code, arr.shape)


def _get_scalar(name: str, scalar_type: int):
    """Build a ValueInfoProto for a scalar of the given onnx element type."""
    return onnx.helper.make_value_info(name, scalar_type)  # type: ignore


def make_onnx_types(name: str, input_data: ArrayLike) -> onnx.ValueInfoProto:
    """Function to map inputs to OnnxOpData to onnx types

    Args:
        name (str): The name of the input
        v (ArrayLike): The input data

    Raises:
        ValueError: If the input data is not a supported type (np.ndarray, list, tuple, int, float, bool) then an error is raised.

    Returns:
        onnx.TensorProto | onnx.ValueInfoProto: returns a tensor proto or value info proto based on the input data.
    """
    match input_data:
        case np.ndarray():
            return _get_tensor(name, input_data)  # type: ignore
        case list() | tuple():
            return _get_tensor(name, validate_sequence(input_data))  # type: ignore
        case int():
            return _get_scalar(name, onnx.ValueInfoProto.INT64)
        case float():
            return _get_scalar(name, onnx.ValueInfoProto.FLOAT)
        case bool():
            return _get_scalar(name, onnx.ValueInfoProto.BOOL)
        case _:
            raise ValueError(f"Unsupported type: {type(input_data)}")


@dataclass
class OnnxConst:
    """A named constant for an onnx graph, stored as a TensorProto.

    Attributes:
        name (str): the onnx value name this constant produces.
        value (InitVar[ArrayLike]): the constant data; consumed by __post_init__.
    """

    name: str
    value: InitVar[ArrayLike]
    __tensor: onnx.TensorProto = field(init=False, default_factory=onnx.TensorProto)  # type: ignore

    def __post_init__(self, value):
        # Coerce python scalars and sequences to an ndarray so dtype/shape exist.
        needs_cast = np.isscalar(value) or isinstance(value, (list, tuple))
        arr = np.array(value) if needs_cast else value
        self.__tensor = onnx.helper.make_tensor(
            name=self.name,
            data_type=onnx.helper.np_dtype_to_tensor_dtype(arr.dtype),
            dims=arr.shape,
            vals=arr.flatten(),
        )

    def to_onnx(self):
        """Emit this constant as an onnx Constant node producing `self.name`."""
        return onnx.helper.make_node(
            "Constant",
            inputs=[],
            outputs=[self.name],
            value=self.__tensor,
        )

    def to_ndarray(self):
        """Reconstruct the constant's value from the stored tensor's raw bytes."""
        np_dtype = onnx.helper.tensor_dtype_to_np_dtype(self.__tensor.data_type)
        flat = np.frombuffer(self.__tensor.raw_data, np_dtype)
        return flat.reshape(self.__tensor.dims)


@dataclass
class OnnxOpData:
    """helper for generating and validating nodes for testing operator support.

    Attributes:
        op_name (str): The name of the operator. Must match the name of the operator in onnx.
        inputs (dict[str, ArrayLike | NodeOutput]): The inputs to the operator.
        output (dict[str, ArrayLike]): The expected output of the operator.
        count (int): sequence number used to give the node a unique onnx name.
    """

    op_name: str
    inputs: dict[str, ArrayLike | NodeOutput]
    output: dict[str, ArrayLike]
    count: int = field(init=False)

    def __post_init__(self):
        # Take the next global op number so node names are unique per run.
        global op_counter
        op_counter += 1
        self.count = op_counter

    @property
    def input_names(self) -> List[str]:
        return list(self.inputs.keys())

    @property
    def output_names(self) -> List[str]:
        return list(self.output.keys())

    @property
    def output_vals(self) -> List[onnx.ValueInfoProto]:
        return [make_onnx_types(k, v) for k, v in self.output.items()]

    @property
    def input_vals(self) -> List[onnx.TensorProto | onnx.ValueInfoProto]:
        # NodeOutput is a typing.NewType over str, so at runtime its instances
        # are plain strings and `type(v) != NodeOutput` was always True (the
        # filter never excluded anything). Exclude all strings instead: they
        # name values produced elsewhere in the graph, not concrete data.
        return [
            make_onnx_types(k, v)
            for k, v in self.inputs.items()
            if not isinstance(v, str)
        ]

    def to_onnx(self):
        """Build the onnx NodeProto for this op (inputs/outputs by name)."""
        return onnx.helper.make_node(
            self.op_name,
            inputs=self.input_names,
            outputs=self.output_names,
            name=f"{self.op_name}{self.count}",
        )


def _get_path(
    graph_name: str, path: Optional[Path | str] = None, ext: str = ".onnx"
) -> str:
    out_path = Path(".") / f"{graph_name.lower()}{ext}"
    if path:
        if (tmp := Path(path)).is_dir():
            out_path = tmp / f"{graph_name.lower()}{ext}"
        elif tmp.suffix != ".onnx":
            raise ValueError(
                f"Provide path {path} must include the model name and end with .onnx extension"
            )

    return str(out_path)


@dataclass
class OnnxGraphBuilder:
    """Builds an onnx graph from a chain of nodes for operator testing.

    The constructor registers the first node; later nodes are appended with
    `add_node`. String inputs that name a previous node's output (or ">")
    become internal graph edges instead of graph inputs, and
    `rhs_constant=True` embeds a node's second input as a Constant node.
    """

    name: InitVar[str]
    inputs: InitVar[dict[str, ArrayLike]]
    output: InitVar[dict[str, ArrayLike]]
    rhs_constant: InitVar[bool] = field(default=False)

    graph_name: str = field(init=False)
    # NOTE(review): value_map is declared but never assigned anywhere in this
    # class; kept so the field list stays backward-compatible — confirm intent.
    value_map: dict[str, ArrayLike] = field(init=False)
    node_counter: int = field(init=False, default=0)
    output_set: set[NodeOutput] = field(init=False, default_factory=set)
    nodes: List[OnnxOpData] = field(init=False, default_factory=list)
    constants: dict[str, OnnxConst] = field(default_factory=dict)

    def __post_init__(self, name, inputs, output, rhs_constant):
        self.graph_name = f"{name}"
        if rhs_constant:
            # NOTE(review): first_idx is never read inside this class; kept
            # only in case an external consumer relies on it.
            self.first_idx = 1
        # The first node goes through the same path as every later node; the
        # original duplicated add_node's body here verbatim.
        self.add_node(name, inputs, output, rhs_constant)

    def add_node(
        self,
        name: str,
        inputs: dict[str, ArrayLike],
        output: dict[str, ArrayLike],
        rhs_constant: bool = False,
    ):
        """Validate and append a node to the graph.

        Args:
            name (str): onnx operator name for the node.
            inputs (dict[str, ArrayLike]): inputs to the node; string values
                (or ">") reference outputs of previous nodes.
            output (dict[str, ArrayLike]): expected outputs of the node.
            rhs_constant (bool): if True, embed the node's second input in the
                graph as a Constant node instead of a graph input.

        Raises:
            ValueError: if ">" is used with no previous node, or if an output
                or constant name already exists in the graph.
            KeyError: if ">" is used but the previous node has multiple outputs.
        """
        out_names = set(output.keys())
        validated_input = validate_input(inputs, self.output_set, out_names)

        if ">" in validated_input:
            # ">" only works if there is a previous node with a single output.
            if not self.nodes:
                raise ValueError("'>' used but the graph has no previous node")
            if len((prev_out := self.nodes[-1].output_names)) != 1:
                raise KeyError(
                    "Previous node has more than one output. Please specify the output to use"
                )
            # Substitute ">" in place so the node's input ORDER is preserved
            # (input order is significant for onnx operators).
            validated_input = {
                (prev_out[0] if k == ">" else k): (
                    NodeOutput(prev_out[0]) if k == ">" else v
                )
                for k, v in validated_input.items()
            }

        for k in output:
            out_name = NodeOutput(k)
            if out_name in self.output_set:
                raise ValueError(f"Output {k} already exists in the graph")
            self.output_set.add(out_name)

        if rhs_constant:
            # Embed the second input as a graph constant; its name joins
            # output_set so it is treated as graph-internal, not a graph input.
            const_name = list(validated_input.keys())[1]
            const_val = validated_input[const_name]
            self.constants[const_name] = OnnxConst(const_name, const_val)
            const_out = NodeOutput(const_name)
            if const_out in self.output_set:
                raise ValueError(
                    f"Name for rhs const {const_name}  already exists in the graph"
                )
            self.output_set.add(const_out)

        # Store the VALIDATED inputs so NodeOutput references (including the
        # resolved ">") survive into the node. The original appended the raw
        # `inputs` dict, silently discarding the ">" substitution above.
        self.nodes.append(OnnxOpData(name, validated_input, output))

    @property
    def graph_nodes(self):
        """All onnx nodes for the graph: constants first, then op nodes."""
        result = [const.to_onnx() for const in self.constants.values()]
        result.extend(node.to_onnx() for node in self.nodes)
        return result

    def get_graph_inputs(self) -> List[onnx.ValueInfoProto]:
        """ValueInfos for every input that is neither a constant nor an
        internal edge (i.e. not produced by another node)."""
        res: List[onnx.ValueInfoProto] = []
        for node in self.nodes:
            res.extend(
                make_onnx_types(k, v)
                for k, v in node.inputs.items()
                if k not in self.constants and NodeOutput(k) not in self.output_set
            )
        return res

    def get_graph_outputs(self):
        """ValueInfos for the final node's outputs (the graph outputs)."""
        return self.nodes[-1].output_vals

    def get_output_names(self):
        """Names of the final node's outputs."""
        return self.nodes[-1].output_names

    def get_expected_outputs(self):
        """Expected output arrays of the final node, for result validation."""
        return list(self.nodes[-1].output.values())

    def make_onnx_graph(self) -> onnx.GraphProto:
        """Assemble the accumulated nodes/constants into an onnx GraphProto."""
        graph: onnx.GraphProto = onnx.helper.make_graph(
            self.graph_nodes,
            self.graph_name,
            inputs=self.get_graph_inputs(),  # type: ignore
            outputs=self.get_graph_outputs(),  # type: ignore
        )
        return graph

    def save_model(self, path: Optional[Path | str] = None):
        """Converts the generated graph to an onnx model and saves it to a file.

        Args:
            path (Optional[Path | str], optional): desired path to the output.
                If unspecified, defaults to {graph_name}.onnx.

        Raises:
            ValueError: If you provide a path and it doesn't end with .onnx
                then an error is raised.
        """
        model = onnx.helper.make_model(self.make_onnx_graph())
        out_path = _get_path(self.graph_name, path)
        onnx.save(model, out_path)
        print(f"Model saved to {out_path}")

    def get_sess_inputs(self):
        """Feed dict for onnxruntime: every input that is neither a constant
        nor produced by another node in the graph."""
        res = {}
        for node in self.nodes:
            for k, v in node.inputs.items():
                if k not in self.constants and NodeOutput(k) not in self.output_set:
                    res[k] = v
        return res

    def validate_model(self, model_path: Optional[Path | str] = None):
        """Loads the generated model and runs it with the provided inputs to
        validate the output. More of a sanity check than anything else.

        Args:
            model_path (Optional[Path | str], optional): path of the saved
                model; defaults to {graph_name}.onnx in the current directory.

        Returns:
            Outputs (Any): the outputs of the model, for further inspection.
        """
        model_path = _get_path(self.graph_name, model_path)
        sess = onnxruntime.InferenceSession(model_path)
        # (the original bound sess.get_inputs() to an unused local here)
        sess_outputs = sess.run(
            self.get_output_names(),
            self.get_sess_inputs(),
        )
        for expected, actual in zip(self.get_expected_outputs(), sess_outputs):
            assert np.allclose(expected, actual)
        print("Output is the same as expected. Test passed.")
        return sess_outputs

    def model_to_txt(self, path: Optional[Path | str] = None):
        """load the generated model and save it to a txt file for debugging purposes.

        Args:
            path (Optional[Path | str], optional): desired path to the output.
                If unspecified, defaults to {graph_name}.txt in the current
                directory. Defaults to None.

        Raises:
            ValueError: If you provide a path and it doesn't end with .txt then an error is raised
        """
        model = onnx.helper.make_model(self.make_onnx_graph())
        out_path = _get_path(self.graph_name, path, ".txt")
        with open(out_path, "w") as f:
            f.write(str(model))
        print(f"Model saved to {out_path}")


if __name__ == "__main__":
    # Smoke test: chain two Unsqueeze nodes and check them against numpy.
    const_axes = [0, 4]
    axis = [1]
    x = np.array(np.random.randn(3, 4, 5))
    # Compute the expected results with numpy before building the onnx graph.
    y = np.expand_dims(x, axis=const_axes)
    z = np.expand_dims(y, axis=axis)

    if y.shape != (1, 3, 4, 5, 1):
        raise ValueError(f"Expected shape (1,3,4,5,1) but got {y.shape}")
    if z.shape != (1, 1, 3, 4, 5, 1):
        raise ValueError(f"Expected shape (1,1,3,4,5,1) but got {z.shape}")

    data = OnnxGraphBuilder(
        "Unsqueeze",
        {"x": x, "axes": const_axes},
        {"y": y},
        rhs_constant=True,
    )
    data.add_node("Unsqueeze", {"y": "y", "axis": axis}, {"z": z})

    # data.model_to_txt()
    # The model must be written to disk first: validate_model loads it from
    # the default path, and nothing else saves it.
    data.save_model()
    result = data.validate_model()

    assert np.allclose(result[0], data.get_expected_outputs()[0])
    print("Test passed")
